import torch
from surrogate import squared_loss, squared_hinge_loss
import numpy as np
import torch.nn as nn

class tpAUC_KL_loss(nn.Module):
    def __init__(self, pos_length, Lambda=1.0, tau=1.0, threshold=1.0, loss_type='sh', device=None):
        '''
        Pairwise AUC loss whose pair losses are reweighted through exponential
        (KL-regularized) moving-average estimators.

        Keeps one moving-average value per positive example (``u_pos``) and one
        scalar moving average (``w``). Both are updated in place by ``forward``
        and combined into detached importance weights for the pair losses.

        param
        pos_length: number of positive examples for the training data
            (rows of ``u_pos``; every ``index_p`` must lie in [0, pos_length))
        Lambda: KL regularization for negative samples
        tau: KL regularization for positive samples
        threshold: margin for basic AUC loss
        loss_type: basic AUC loss to apply, 'sh' (squared_hinge_loss) or
            'sq' (squared_loss)
        device: device that holds the internal state; defaults to CUDA when it
            is available, otherwise CPU

        The step sizes are not constructor arguments. They start at 0.9 and can
        be changed with ``set_gammas`` / ``update_gammas``:
        gamma_0: stepsize for negative sample KL regularization term (``u_pos``)
        gamma_1: stepsize for positive sample KL regularization term (``w``)

        raises
        ValueError: if loss_type is not 'sh' or 'sq'
        '''
        super(tpAUC_KL_loss, self).__init__()
        if loss_type not in ('sh', 'sq'):
            # Fail fast: otherwise self.Loss stays unset and forward() later
            # dies with an unhelpful AttributeError.
            raise ValueError(f"loss_type must be 'sh' or 'sq', got {loss_type!r}")
        if not device:
            # Same auto-detection as tpAUC_CVaR_loss; CUDA used to be hard-coded,
            # which broke construction on CPU-only machines.
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device
        self.gamma_0 = 0.9
        self.gamma_1 = 0.9
        self.Lambda = Lambda
        self.tau = tau
        # One moving-average slot per positive example, shape (pos_length, 1).
        self.u_pos = torch.zeros(pos_length, 1, device=self.device)
        self.w = 0.0
        self.threshold = threshold
        self.Loss = squared_hinge_loss if loss_type == 'sh' else squared_loss
        print('The loss type is :', loss_type)

    def set_gammas(self, gamma_0, gamma_1):
        '''Overwrite the moving-average step sizes for ``u_pos`` and ``w``.'''
        self.gamma_0 = gamma_0
        self.gamma_1 = gamma_1

    def update_gammas(self, decay_factor):
        '''Divide both moving-average step sizes by ``decay_factor``.'''
        self.gamma_0 = self.gamma_0 / decay_factor
        self.gamma_1 = self.gamma_1 / decay_factor

    def forward(self, y_pred, y_true, index_p, index_n):
        '''
        Compute the weighted pairwise loss for one mini-batch and update the
        moving-average state (``u_pos`` rows at ``index_p``, and ``w``) in place.

        param
        y_pred: model scores; positives and negatives are each flattened
        y_true: labels aligned with y_pred; 1 marks positives, 0 negatives
        index_p: indices into ``u_pos`` for this batch's positives (assumed to
            follow the order of the positives in y_pred -- TODO confirm with
            the sampler)
        index_n: unused; kept for interface compatibility

        return
        Scalar loss tensor. If the batch has no positives or no negatives, a
        zero still attached to y_pred's graph is returned and the state is not
        touched.
        '''
        f_ps = y_pred[y_true == 1].view(-1, 1)
        f_ns = y_pred[y_true == 0].view(1, -1)
        if f_ps.numel() == 0 or f_ns.numel() == 0:
            # A mean over an empty side is NaN and would permanently poison
            # u_pos / w, so skip the update for this batch.
            return y_pred.sum() * 0.0

        # (num_pos, num_neg) matrix of score gaps f(x_p) - f(x_n), by broadcasting.
        difference = f_ps - f_ns

        loss = self.Loss(margin=self.threshold, t=difference)  # per-pair, before averaging
        exp_loss = torch.exp(loss / self.Lambda).detach()

        self.u_pos[index_p] = (1 - self.gamma_0) * self.u_pos[index_p] + self.gamma_0 * exp_loss.mean(1, keepdim=True)
        self.w = (1 - self.gamma_1) * self.w + self.gamma_1 * torch.pow(self.u_pos[index_p], self.Lambda / self.tau).mean()

        # Importance weights; detached so gradients flow only through `loss`.
        p = (torch.pow(self.u_pos[index_p], self.Lambda / self.tau - 1) * exp_loss / self.w).detach()
        return torch.mean(p * loss)




class tpAUC_CVaR_loss(nn.Module):
    def __init__(self, data_length, threshold=1.0, gamma=0.9, eta=1e-1, rate=0.5, momentum=0.0, device=None):
        '''
        CVaR-style pairwise AUC loss built on a squared hinge, with
        per-example state updated in place by ``forward``.

        :param data_length: number of rows in the per-example state buffers
            ``u`` and ``s1``; every index passed as ``ids_p`` must be below it
        :param threshold: margin for squared hinge loss
        :param gamma: moving-average weight for the ``u`` estimator
        :param eta: step size for both threshold variables ``s1`` (per example)
            and ``s2`` (global)
        :param rate: divisor applied to the hinge terms and used in the
            ``s1`` / ``s2`` update steps
        :param momentum: if > 0, a buffer of the previous estimate is kept and
            ``momentum * (new - previous)`` is added to the ``u`` update
        :param device: device holding the state; defaults to CUDA when it is
            available, otherwise CPU
        '''
        super(tpAUC_CVaR_loss, self).__init__()
        if not device:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        self.gamma = gamma
        self.eta_1 = eta
        self.eta_2 = eta
        self.data_length = data_length
        self.rate = rate
        self.u = torch.ones(data_length, 1, device=self.device)
        self.s1 = torch.zeros(data_length, 1, device=self.device)
        self.s2 = torch.zeros(1, 1, device=self.device)
        self.threshold = threshold
        self.momentum = momentum
        if momentum > 0.0:
            self.comp_u = torch.ones(data_length, 1, device=self.device)

    def update_smoothing(self, decay_factor):
        '''
        Divide the moving-average weight ``gamma`` and both step sizes
        (``eta_1``, ``eta_2``) by ``decay_factor``.
        '''
        # This used to reference self.gamma_1 / self.gamma_2, which the class
        # never defines, so every call raised AttributeError.
        self.gamma = self.gamma / decay_factor
        self.eta_1 = self.eta_1 / decay_factor
        self.eta_2 = self.eta_2 / decay_factor

    def forward(self, y_pred, y_true, ids_p, ids_n):
        '''
        Compute the loss for one mini-batch and update ``u``, ``s1`` (rows at
        ``ids_p``) and ``s2`` in place.

        :param y_pred: model scores; positives and negatives are each flattened
        :param y_true: labels aligned with y_pred; 1 marks positives, 0 negatives
        :param ids_p: indices into the state buffers for this batch's positives
            (assumed to follow the order of the positives in y_pred -- TODO
            confirm with the sampler)
        :param ids_n: unused; kept for interface compatibility
        :return: scalar loss tensor. If the batch has no positives or no
            negatives, a zero still attached to y_pred's graph is returned and
            the state is not touched.
        '''
        v_p = y_pred[y_true == 1].view(-1, 1)
        v_n = y_pred[y_true == 0].view(1, -1)
        if v_p.numel() == 0 or v_n.numel() == 0:
            # A mean over an empty side is NaN and would permanently poison
            # u / s1 / s2, so skip this batch.
            return y_pred.sum() * 0.0

        # (num_pos, num_neg) squared-hinge pair losses, by broadcasting.
        loss = torch.clip(self.threshold - (v_p - v_n), min=0) ** 2
        s1 = self.s1[ids_p]  # read once, before s1 is updated below
        # Constant (non-differentiable) indicators: pair loss above its row
        # threshold s1, and row estimate u above the global threshold s2.
        p1 = (loss.detach() > s1).float()
        p2 = (self.u[ids_p] > self.s2).float()
        tp_loss = (loss / self.rate * p1).mean(dim=-1, keepdim=True)
        tp_loss = (tp_loss / self.rate * p2).mean()

        comp_u = (s1 + torch.clip(loss.detach() - s1, min=0) / self.rate).mean(dim=-1, keepdim=True)
        momentum_comp_u = 0.0
        if self.momentum > 0:
            momentum_comp_u = self.momentum * (comp_u - self.comp_u[ids_p])
            self.comp_u[ids_p] = comp_u
        self.u[ids_p] = (1 - self.gamma) * self.u[ids_p] + self.gamma * comp_u + momentum_comp_u
        self.s1[ids_p] -= self.eta_1 / len(ids_p) * p2 / self.rate * (1 - p1.mean(dim=-1, keepdim=True) / self.rate)
        self.s2 -= self.eta_2 * (1 - p2.mean() / self.rate)

        return tp_loss

